The previous notebook got messy, so this is a cleaner one that lets me sample over both emulators simultaneously.
In [1]:
import matplotlib
#matplotlib.use('Agg')
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
# Apply seaborn's default plot styling to all matplotlib figures in this notebook.
sns.set()
In [2]:
import numpy as np
import h5py
from chainconsumer import ChainConsumer
#from corner import corner
from ast import literal_eval
from pearce.emulator import LemonPepperWet
from os import path
from scipy.linalg import inv
In [3]:
# MCMC output to analyze. An HOD vdf run (HOD_vdf_rmin_None_HOD.hdf5) lives in the same directory.
mcmc_dir = '/u/ki/swmclau2/des/PearceMCMC'
fname = path.join(mcmc_dir, 'UniverseMachine_wp_ds_rmin_0.5_HOD.hdf5')
In [4]:
# Open the MCMC output read-only. The handle stays open: later cells read its attrs
# and the 'chain', 'data' and 'cov' datasets.
f = h5py.File(fname, 'r')
In [5]:
# Emulator configuration recorded alongside the chain.
tf = f.attrs['training_file']
fixed_params = literal_eval(f.attrs['fixed_params'])
# rmin is used further down as a scale cut and is not passed to the emulators.
fixed_params.pop('rmin', None)
emu_hps = literal_eval(f.attrs['emu_hps'])
In [6]:
# Training file list; one emulator is built per entry in the next cell.
tf
Out[6]:
In [7]:
emus = []
for t in tf:
print t
emus.append(LemonPepperWet(t, fixed_params = fixed_params, hyperparams = emu_hps) )
In [8]:
# Names of the sampled parameters, in chain column order.
chain_pnames = f.attrs['param_names']
In [9]:
# Number of walkers used by the sampler.
n_walkers = f.attrs['nwalkers']
In [10]:
# Burn-in: drop the first n_burn steps. The row offset n_burn*n_walkers assumes the flat
# chain holds one row per walker per step, stored step-major (all walkers of a step together).
# NOTE(review): this matches the reshape in the ChainConsumer cell; confirm against the sampler output.
n_burn = 1000
chain = f['chain'][n_burn*n_walkers:, :]
In [11]:
# Scale cut: only bins with r > rmin enter the ratio plots and chi^2 below.
# The attribute is parsed with literal_eval rather than eval because it is file content.
# 'rmin' is looked up as a dict key rather than as a substring of the raw string.
# A missing or None rmin means "no cut".
rmin = literal_eval(f.attrs['fixed_params']).get('rmin')
if rmin is None:
    rmin = 0
In [12]:
# Number of sampled parameters (a 1-D chain means a single parameter).
if chain.ndim > 1:
    n_params = chain.shape[1]
else:
    n_params = 1
In [13]:
# Post-burn-in chain shape and the number of steps per walker (integer division in Python 2).
print chain.shape, chain.shape[0]/n_walkers
In [14]:
# Load the post-burn-in chain into ChainConsumer.
# The step-major flat chain is reordered to walker-major (each walker's steps contiguous)
# via an 'F'-order reshape. This is presumably the layout ChainConsumer's `walkers`
# diagnostics expect (equal contiguous chunks).
# The chain is re-sliced from the file so re-running this cell does not re-permute an
# already reordered chain.
c = ChainConsumer()
chain = f['chain'][n_burn*n_walkers:, :]
n_dim = chain.shape[1]
chain = chain.reshape((-1, n_walkers, n_dim)).reshape((-1, n_dim), order = 'F')
c.add_chain(chain, parameters=list(chain_pnames), walkers = n_walkers)
Out[14]:
In [15]:
# Per-parameter point estimate taken from ChainConsumer's summary of chain 0.
# NOTE(review): summary[key] is presumably (lower, central, upper), so index 1 is the central
# value. It is likely a marginal statistic rather than the joint maximum a posteriori, and can
# be None for unconstrained parameters. Confirm for this ChainConsumer version.
#MAP = chain.mean(axis = 0)
summary = c.analysis.get_summary(chains=0)
MAP = np.array([summary[key][1] for key in chain_pnames])
print MAP
In [16]:
# Sampled parameter names.
chain_pnames
Out[16]:
In [17]:
# Parameter name -> point estimate.
MAP_dict = {pname: value for pname, value in zip(chain_pnames, MAP)}
In [18]:
#fixed_params = eval(f.attrs['chain_fixed_params'])
#fixed_params.update(eval(f.attrs['sim'])['cosmo_params'])
In [19]:
# Add the fixed (non-sampled) parameters so the dict fully specifies an emulator input.
# A key present in both dicts takes the fixed_params value.
MAP_dict.update(fixed_params)
In [20]:
# Full parameter dict used for the MAP predictions.
MAP_dict
Out[20]:
In [21]:
# Emulator predictions at the MAP point, one row per emulator.
# These appear to be log10 of the observable, since later cells plot 10**MAP_pred.
MAP_pred = np.hstack([e.emulate_wrt_r(MAP_dict, r_bin_centers=e.scale_bin_centers).squeeze() for e in emus])
MAP_pred = MAP_pred.reshape((len(emus), -1))
# Later cells use `emu` as a reference emulator for shared metadata (bin centers, parameter
# names/bounds, obs). Bind it explicitly instead of relying on Python 2 list comprehensions
# leaking their loop variable. It is the same object: the last emulator.
emu = emus[-1]
In [22]:
# Chain column indices: the first 7 columns are cosmology, the rest are HOD parameters.
# np.arange keeps an integer dtype even for an empty range, unlike np.array(range(...)),
# which yields a float64 array that cannot be used for indexing.
hod_idxs = np.arange(7, len(chain_pnames))
cosmo_idxs = np.arange(7)
In [23]:
# Split the chain into cosmology (first 7 columns) and HOD (remaining columns) sub-chains.
# A 7-column chain is cosmology-only, so no HOD sub-chain is created.
if chain.shape[1] != 7:
    cosmo_chain, hod_chain = chain[:, :7], chain[:, 7:]
else:
    cosmo_chain = chain
In [24]:
from pearce.mocks import cat_dict
# Keyword arguments for constructing the test-box catalog (filled in the next cell).
cosmo_params = {}
In [25]:
# Select test box 1, realization 0.
cosmo_params.update(boxno=1, realization=0)
In [26]:
# Test-box catalog; its cosmology seeds the truth dictionary further down.
cat = cat_dict['testbox'](**cosmo_params)#construct the specified catalog!
In [27]:
# Map each of the test box's cosmological parameter names to its value.
cpv = cat._get_cosmo_param_names_vals()
cosmo_names, cosmo_vals = cpv[0], cpv[1]
cat_val_dict = dict(zip(cosmo_names, cosmo_vals))
In [28]:
# Test-box cosmology.
cat_val_dict
Out[28]:
In [29]:
# Truth: the test box's cosmology plus the HOD parameters used to build the mock data.
true_param_dict = cat_val_dict.copy()
hod_params = {'alpha': 1.083, 'conc_gal_bias': 1.0, 'logM0': 13.2, 'logM1': 14.2, 'sigma_logM': 0.2}
# Alternative HOD tried previously:
# hod_params = {'alpha': 1.03887697, 'conc_gal_bias': 1.0, 'logM0': 11.41003864,
#               'logM1': 14.56088772, 'sigma_logM': 0.44415644}
# (assembly-bias params/slopes/corrs can be zeroed here for emulators that vary them)
true_param_dict.update((name, value) for name, value in hod_params.items() if name != 'logMmin')
true_param_dict['conc_gal_bias'] = 1.0
# Uses the reference emulator `emu`; true_pred is recomputed per emulator in later cells.
true_pred = emu.emulate_wrt_r(true_param_dict)[0]
In [30]:
# Truth parameter dict.
true_param_dict
Out[30]:
In [31]:
# Data vector stored with the chain; [()] reads the whole dataset into memory.
true_data = f['data'][()]#.flatten()
In [32]:
# One row per emulator/observable.
true_data = np.reshape(true_data, (len(emus), -1))
In [33]:
# Quick look at the first observable's data vector on log-log axes.
plt.loglog(emu.scale_bin_centers, true_data[0]);
#plt.loglog(emu.scale_bin_centers, true_data[1]);
In [34]:
# Bin centers used as x-values for every emulator in the plots below.
# NOTE(review): assumes all emulators share the reference emulator's bins; confirm.
rbc = emu.scale_bin_centers#[-len(emu.scale_bin_centers):]
In [35]:
# Full covariance across all observables, and per-bin 1-sigma errors (one row per emulator).
cov = f['cov'][()]
yerr = np.sqrt(cov.diagonal()).reshape(len(emus), -1)
In [36]:
# First observable's data vector.
true_data[0]
Out[36]:
In [37]:
# Fractional 1-sigma error on the first observable.
yerr[0]/true_data[0]
Out[37]:
In [38]:
# Bin centers.
rbc
Out[38]:
In [39]:
def cov_to_corr(cov):
    """Return the correlation matrix corresponding to covariance matrix `cov`.

    Each entry is divided by the product of the standard deviations of its row and
    column variables, which puts ones on the diagonal.
    """
    sigma = np.sqrt(np.diag(cov))
    return cov / (sigma[:, None] * sigma[None, :])
In [40]:
# Blue-to-red diverging colormap for the correlation-matrix plot.
cmap = sns.diverging_palette(240, 10, n=7, as_cmap = True)
In [41]:
# Correlation matrix of the full covariance (all observables stacked).
# The color limits are pinned to [-1, 1] so the diverging colormap is centered on zero
# correlation. Otherwise vmax comes from the data, which may differ slightly from 1.
plt.imshow(cov_to_corr(cov), cmap=cmap, vmin=-1, vmax=1)
plt.colorbar();
Out[41]:
In [42]:
# Full data array (one row per observable).
true_data
Out[42]:
In [43]:
# Truth parameter dict (re-displayed).
true_param_dict
Out[43]:
In [44]:
# First observable's data vector. NOTE(review): duplicates an earlier display cell.
true_data[0]
Out[44]:
In [45]:
# Data vs. emulator predictions at the MAP and at the truth, one log-log panel per emulator.
# The grid has one column per emulator (previously hard-coded to 2); figure width scales
# to match and equals the old (12, 5) for two emulators.
n_obs = len(emus)
fig = plt.figure(figsize = (6*n_obs, 5))
for i in xrange(n_obs):
    plt.subplot(1, n_obs, i+1)
    true_pred = emus[i].emulate_wrt_r(true_param_dict).squeeze()
    plt.errorbar(rbc, true_data[i], yerr=yerr[i], label = 'Data')
    plt.plot(rbc, 10**MAP_pred[i], label = 'MAP')
    plt.plot(rbc, (10**true_pred), label = 'Emu at Truth')
    plt.loglog()
    plt.legend(loc='best')
plt.show();
In [46]:
# Boolean mask selecting the bins above the rmin scale cut.
rmin_idxs = rbc > rmin
In [47]:
# Ratios to the data (bins above rmin only): data errors, MAP prediction and truth
# prediction, one panel per emulator.
# The grid has one column per emulator (previously hard-coded to 2).
n_obs = len(emus)
fig = plt.figure(figsize = (6*n_obs, 5))
for i in xrange(n_obs):
    plt.subplot(1, n_obs, i+1)
    true_pred = emus[i].emulate_wrt_r(true_param_dict).squeeze()
    r_cut = rbc[rmin_idxs]
    data_cut = true_data[i][rmin_idxs]
    plt.errorbar(r_cut, np.ones_like(data_cut), yerr=yerr[i][rmin_idxs]/data_cut, label = 'Data')
    plt.plot(r_cut, (10**MAP_pred[i][rmin_idxs])/data_cut, label = 'MAP')
    plt.plot(r_cut, (10**true_pred[rmin_idxs])/data_cut, label = 'Emu at Truth')
    plt.legend(loc='best')
    plt.xscale('log')
plt.show();
In [48]:
# Observable of the reference emulator `emu` (the last entry of emus).
emu.obs
Out[48]:
In [49]:
# Bin centers (re-displayed).
rbc
Out[49]:
In [50]:
npart_aemulus = 1400**3
npart_mdpl2 = 3840**3
downsample_factor = 1e-2
npart_aemulus_ds = npart_aemulus*downsample_factor
print npart_aemulus_ds, npart_aemulus_ds/npart_mdpl2
In [51]:
# Sampled parameter names (re-displayed).
chain_pnames
Out[51]:
In [52]:
# Emulator predictions at the true parameters, stacked one row per emulator.
true_rows = [emu.emulate_wrt_r(true_param_dict, r_bin_centers=emu.scale_bin_centers).squeeze()
             for emu in emus]
true_pred = np.hstack(true_rows).reshape((len(emus), -1))
In [53]:
# The rmin mask repeated once per observable, matching the stacked covariance ordering.
cov_rmin_idxs = np.concatenate([rmin_idxs] * len(emus))
In [54]:
print 'True Red. Chi2'
R = (10**true_pred[:, rmin_idxs].flatten()-true_data[:, rmin_idxs].flatten())
chi2 = R.T.dot(inv(cov[cov_rmin_idxs][:,cov_rmin_idxs])).dot(R)
dof = len(chain_pnames)
print chi2/dof
In [55]:
print 'Map Red. Chi2'
R = (10**MAP_pred[:, rmin_idxs].flatten()-true_data[:, rmin_idxs].flatten())
chi2 = R.T.dot(inv(cov[cov_rmin_idxs][:,cov_rmin_idxs])).dot(R)
dof = len(chain_pnames)
print chi2/dof
In [56]:
# N colors for the H0 sweep below.
# NOTE(review): this rebinds `cmap`, which earlier held the correlation-matrix colormap,
# to a list of colors.
N = 6
cmap = sns.color_palette("BrBG_d", N)
In [57]:
# Parameters varied by the reference emulator.
emu.get_param_names()
Out[57]:
In [58]:
# Reference little-h used by h_factor.
# NOTE(review): presumably the truth cosmology's h; confirm against cat_val_dict['H0'].
true_h = 0.677
In [59]:
def h_factor(h):
    """Return (h / true_h) to the first power; used to rescale predictions made at little-h `h`."""
    return (h / true_h) ** 1
In [60]:
# Sensitivity to H0, one panel per emulator.
# H0 is swept across the reference emulator's bounds with every other parameter at the MAP.
# Each prediction (scaled by h_factor) is plotted relative to the data.
# The black line is the prediction at the MAP itself, so it is labeled 'MAP', not 'Truth'.
varied_pname = 'H0'
lower, upper = emu.get_param_bounds(varied_pname)
n_obs = len(emus)
fig = plt.figure(figsize=(7.5*n_obs, 6))
for i in xrange(n_obs):
    plt.subplot(1, n_obs, i+1)
    data_cut = true_data[i][rmin_idxs]
    plt.errorbar(rbc[rmin_idxs], np.ones_like(data_cut),
                 yerr=yerr[i][rmin_idxs]/data_cut, label = 'Data')
    pred = emus[i].emulate_wrt_r(MAP_dict).squeeze()
    plt.plot(rbc[rmin_idxs], h_factor(MAP_dict['H0']/100)*(10**pred[rmin_idxs])/data_cut,
             label = 'MAP', color = 'k')
    plt.xscale('log')
# The loop variable is `color` rather than `c` so the ChainConsumer object `c` is not clobbered.
for color, val in zip(cmap, np.linspace(lower, upper, N)):
    param_dict = MAP_dict.copy()
    param_dict[varied_pname] = val
    for i in xrange(n_obs):
        plt.subplot(1, n_obs, i+1)
        pred = emus[i].emulate_wrt_r(param_dict).squeeze()
        plt.plot(rbc[rmin_idxs], h_factor(val/100)*(10**pred[rmin_idxs])/true_data[i][rmin_idxs],
                 label = '%.3f'%val, color = color)
# Draw a legend on every panel, not just the last one.
for i in xrange(n_obs):
    plt.subplot(1, n_obs, i+1)
    plt.legend(loc='best')
plt.show();
In [61]:
from pearce.mocks.kittens import TrainingBox
In [62]:
# NOTE(review): leftover IPython source introspection; remove before sharing.
?? np.argmin
In [63]:
# Positions of the smallest and largest entries of cat.cosmo_params['H0'].
h0_values = cat.cosmo_params['H0']
np.argmin(h0_values), np.argmax(h0_values)
Out[63]:
In [66]:
# Positions of cat.cosmo_params['H0'] ordered from lowest to highest H0.
np.argsort(cat.cosmo_params['H0'])
Out[66]:
In [67]:
# Low- and high-H0 training boxes.
# NOTE(review): the indices look hand-picked from the argsort above, but that argsort was taken
# over the test catalog's cosmo_params. Confirm that 23/17 index the training boxes as intended.
catlow, cathigh = TrainingBox(23), TrainingBox(17)
In [68]:
# Load the low-H0 box with the zheng07 HOD model.
# NOTE(review): the first argument is presumably the scale factor (a = 1.0); confirm.
catlow.load(1.0, HOD='zheng07')
In [69]:
# Load the high-H0 box with the zheng07 HOD model.
# NOTE(review): the first argument is presumably the scale factor (a = 1.0); confirm.
cathigh.load(1.0, HOD='zheng07')
In [70]:
# Populate with the same HOD parameters used to build the truth dictionary.
catlow.populate(hod_params)
In [71]:
# Populate with the same HOD parameters used to build the truth dictionary.
cathigh.populate(hod_params)
In [72]:
# 19 log-spaced bin edges from 10**-1 to 10**1.6, i.e. 18 bins.
# NOTE(review): the plots below use the emulator's rbc as x-values for wp on these bins;
# confirm the bin centers agree.
rbins = np.logspace(-1.0, 1.6, 19)
In [73]:
# wp of the low-H0 box, computed with the box's original h.
wp_low_orig = catlow.calc_wp(rbins)
In [74]:
# wp of the high-H0 box, computed with the box's original h.
wp_high_orig = cathigh.calc_wp(rbins)
In [75]:
# wp of the low- and high-H0 boxes (original h) on log-log axes.
for wp_orig in (wp_low_orig, wp_high_orig):
    plt.loglog(rbc, wp_orig)
plt.show()
In [76]:
# Remember each box's original h, then set h = 1 before recomputing wp.
# NOTE(review): not idempotent. Re-running this cell records lowh = highh = 1.
lowh = catlow.h
highh = cathigh.h
catlow.h = 1
cathigh.h =1
In [77]:
# wp of the low-H0 box recomputed with h = 1.
wp_low = catlow.calc_wp(rbins)
In [78]:
# wp of the high-H0 box recomputed with h = 1.
wp_high = cathigh.calc_wp(rbins)
In [85]:
# NOTE(review): this rebinds h_factor (previously (h/true_h)**1). The plot below doesn't use it,
# but re-running the H0 sweep after this cell silently changes that plot's scaling.
h_factor = lambda h: h
In [103]:
# Compare wp computed with h = 1 against the original-h wp rescaled by the box's h
# in both r and amplitude.
plt.loglog(rbc, wp_low, label = 'Target')
plt.loglog(rbc, wp_high, label = 'Target High')
plt.loglog(rbc*lowh, lowh*wp_low_orig)
plt.loglog(rbc*highh, highh*wp_high_orig)
plt.legend(loc='best')
plt.show()
In [ ]: